import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import altair as alt
import plotly
import plotly.express as px
import os
# Make sure the output directory for saved figures exists.
# `exist_ok=True` turns this into a no-op when the directory is already
# present, which avoids the check-then-create race of a separate
# `os.path.exists` test.
os.makedirs('img', exist_ok=True)
from IPython.display import IFrame, display_html
def show_fig(fig, filename, width="100%", height=500):
    """Write a Plotly figure to an HTML file and embed that file inline.

    Parameters
    ----------
    fig : Plotly figure passed straight through to ``plotly.offline.plot``.
    filename : path of the HTML file to write and then embed.
    width, height : size of the embedding ``IFrame``.
    """
    # auto_open=False: only write the file, do not open it in a browser.
    plotly.offline.plot(fig, filename=filename, auto_open=False, auto_play=False)
    frame = IFrame(filename, width=width, height=height)
    display_html(frame)
def show_fig2(filename, width="100%", height=500):
    """Embed an already-saved file inline through an ``IFrame``.

    Parameters
    ----------
    filename : path of the file to embed.
    width, height : size of the embedding ``IFrame``.
    """
    frame = IFrame(filename, width=width, height=height)
    display_html(frame)
A heatmap is a common 2-dimensional colored grid used for visualizing the intensity or level of a variable across the levels of two other variables.
Heatmaps are commonly used in bioinformatics to display expression levels of genes across samples, often in conjunction with some cluster analysis to put like genes/samples next to each other for visualization purposes.
Heatmaps can be drawn using all the packages we have seen in this class.
We'll use average monthly temperatures in Washington, DC in 1971-2020, obtained from the National Weather Service (link).
.footnote[This data was extracted using the R package datapasta]
import pandas as pd
# Read the temperature table and reshape it so that the original columns
# become rows and each 'Year' value becomes a column.
dc_weather = (
    pd.read_csv('data/dc_weather.csv')
    .drop(columns='Annual')  # Remove the 'Annual' column
    .set_index('Year')
    .T
)
# After transposing, the row labels are the former column headers.
dc_weather.index.name = 'Month'
# Seaborn version: draw the table as a heatmap on a wide axes.
fig, ax = plt.subplots(figsize=(15, 5))
sns.heatmap(
    dc_weather,
    cmap="inferno",
    cbar_kws={'shrink': 0.8, 'label': 'Degrees (F)'},
    ax=ax,
);
# Re-apply the existing tick labels only to control rotation and alignment.
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, horizontalalignment='right');
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, horizontalalignment='right');
ax.set_title('Average monthly temperature in Washington DC (1971-2020)', loc='left');
# Matplotlib cannot write HTML (savefig rejects the '.html' extension as an
# unsupported format), so save a PNG like the other static figures here.
fig.savefig('img/temp_sns.png')
show_fig2('img/temp_sns.png')
# Pure Matplotlib version: imshow draws the raw value matrix, so the
# colorbar, ticks and tick labels have to be added by hand.
fig, ax = plt.subplots(figsize=(15, 5))
im = ax.imshow(dc_weather.values, cmap='inferno')
cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.set_ylabel('Degrees (F)', rotation=90, va='bottom');
# One tick per column / row of the matrix, labelled from the DataFrame.
ax.set_xticks(np.arange(dc_weather.shape[1]))
ax.set_yticks(np.arange(dc_weather.shape[0]))
ax.set_xticklabels(dc_weather.columns.values, rotation=45, horizontalalignment='right')
ax.set_yticklabels(dc_weather.index.values);
ax.set_title('Average monthly temperature in Washington DC (1971-2020)', loc='left');
# Matplotlib cannot write HTML (savefig rejects the '.html' extension as an
# unsupported format), so save a PNG like the other static figures here.
fig.savefig('img/temp_mpl.png')
show_fig2('img/temp_mpl.png')
# Plotly Express version: the DataFrame is passed directly and only the
# color label is overridden.
fig = px.imshow(dc_weather, labels={'color': 'Temperature (F)'})
show_fig(fig, 'img/temp_plotly.html')
# Altair version: reshape to long format, one row per (Month, Year) pair.
D = dc_weather.reset_index().melt(id_vars='Month', value_name='Avg Temp')
# Fix the category order to the row order of dc_weather.
D['Month'] = pd.Categorical(D['Month'], categories=dc_weather.index)
temp_chart = (
    alt.Chart(D)
    .mark_rect()
    .encode(
        x='Year:N',  # Make year nominal so it is treated as a discrete variable
        y=alt.Y('Month:N', bin=False, sort=None),  # sort=None: keep data order
        color=alt.Color('Avg Temp:Q', scale=alt.Scale(scheme='inferno')),
        tooltip=['Month', 'Year', 'Avg Temp'],
    )
)
temp_chart.save('img/temp_alt.html')
show_fig2('img/temp_alt.html')
In many contexts, especially bioinformatics, heatmaps are used to display similarities between units, using cluster analysis. Typically hierarchical clustering is used.
We will use a breast cancer data set, and look to see if there are individuals who have similar profiles across the variables recorded, and if that might be related to outcome.
from sklearn.datasets import load_breast_cancer
# Load the scikit-learn breast cancer data set and put the feature matrix
# into a DataFrame with one named column per feature.
bc_data = load_breast_cancer()
data = pd.DataFrame(bc_data['data'], columns=bc_data['feature_names'])
data.head()
| | mean radius | mean texture | mean perimeter | mean area | mean smoothness | mean compactness | mean concavity | mean concave points | mean symmetry | mean fractal dimension | radius error | texture error | perimeter error | area error | smoothness error | compactness error | concavity error | concave points error | symmetry error | fractal dimension error | worst radius | worst texture | worst perimeter | worst area | worst smoothness | worst compactness | worst concavity | worst concave points | worst symmetry | worst fractal dimension |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 17.99 | 10.38 | 122.80 | 1001.0 | 0.11840 | 0.27760 | 0.3001 | 0.14710 | 0.2419 | 0.07871 | 1.0950 | 0.9053 | 8.589 | 153.40 | 0.006399 | 0.04904 | 0.05373 | 0.01587 | 0.03003 | 0.006193 | 25.38 | 17.33 | 184.60 | 2019.0 | 0.1622 | 0.6656 | 0.7119 | 0.2654 | 0.4601 | 0.11890 |
| 1 | 20.57 | 17.77 | 132.90 | 1326.0 | 0.08474 | 0.07864 | 0.0869 | 0.07017 | 0.1812 | 0.05667 | 0.5435 | 0.7339 | 3.398 | 74.08 | 0.005225 | 0.01308 | 0.01860 | 0.01340 | 0.01389 | 0.003532 | 24.99 | 23.41 | 158.80 | 1956.0 | 0.1238 | 0.1866 | 0.2416 | 0.1860 | 0.2750 | 0.08902 |
| 2 | 19.69 | 21.25 | 130.00 | 1203.0 | 0.10960 | 0.15990 | 0.1974 | 0.12790 | 0.2069 | 0.05999 | 0.7456 | 0.7869 | 4.585 | 94.03 | 0.006150 | 0.04006 | 0.03832 | 0.02058 | 0.02250 | 0.004571 | 23.57 | 25.53 | 152.50 | 1709.0 | 0.1444 | 0.4245 | 0.4504 | 0.2430 | 0.3613 | 0.08758 |
| 3 | 11.42 | 20.38 | 77.58 | 386.1 | 0.14250 | 0.28390 | 0.2414 | 0.10520 | 0.2597 | 0.09744 | 0.4956 | 1.1560 | 3.445 | 27.23 | 0.009110 | 0.07458 | 0.05661 | 0.01867 | 0.05963 | 0.009208 | 14.91 | 26.50 | 98.87 | 567.7 | 0.2098 | 0.8663 | 0.6869 | 0.2575 | 0.6638 | 0.17300 |
| 4 | 20.29 | 14.34 | 135.10 | 1297.0 | 0.10030 | 0.13280 | 0.1980 | 0.10430 | 0.1809 | 0.05883 | 0.7572 | 0.7813 | 5.438 | 94.44 | 0.011490 | 0.02461 | 0.05688 | 0.01885 | 0.01756 | 0.005115 | 22.54 | 16.67 | 152.20 | 1575.0 | 0.1374 | 0.2050 | 0.4000 | 0.1625 | 0.2364 | 0.07678 |
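In seaborn, the `clustermap` function takes care of the clustering for us. Note that we pass `standard_scale=1`, which rescales each column (variable) to the range 0 to 1, so that variables measured on very different scales can be compared and the differences in patterns are easier to see. (To standardize to mean 0 and variance 1 instead, use the `z_score` argument.) We are also using the correlation metric (1 - correlation) to cluster rows and columns.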
# Clustered heatmap of the feature table.
# standard_scale=1 rescales each column to the [0, 1] range before plotting;
# rows and columns are clustered with average linkage on correlation distance.
# Note that `fig` is the ClusterGrid returned by seaborn, not a Matplotlib Figure.
fig = sns.clustermap(
    data,
    standard_scale=1,
    metric='correlation',
    method='average',
    cmap='RdBu',
);
fig.savefig('img/heatmap_cluster1.png')
show_fig2('img/heatmap_cluster1.png')
/Users/abhijit/opt/anaconda3/envs/biof440/lib/python3.8/site-packages/seaborn/matrix.py:649: UserWarning: Clustering large matrix with scipy. Installing `fastcluster` may give better performance.
We will add a column to the heatmap that color-codes the outcomes, so we can see if the clustering aligns with the outcomes.
# Pair each distinct outcome code (in sorted order) with a color.
outcome_codes = np.unique(bc_data.target)
color_dict = dict(zip(outcome_codes, np.array(['g', 'skyblue'])))
# One color per observation, in the same row order as `data`.
target_df = pd.DataFrame({'target': bc_data.target})
row_colors = target_df['target'].map(color_dict)
# Same clustered heatmap as before, plus a color strip showing the outcome.
sns.clustermap(
    data,
    standard_scale=1,
    metric='correlation',
    method='average',
    cmap='RdBu',
    row_colors=row_colors,
);
<seaborn.matrix.ClusterGrid at 0x7fd1cb74f910>
# Load the measles table and index it by state name.
measles = pd.read_csv('data/measles.csv')
# Title-case the state names and turn '.' separators into spaces.
# regex=False makes the '.' a literal character on every pandas version;
# the default for `regex` changed between releases, and as a regular
# expression '.' would match every character.
measles['state'] = measles['state'].str.title().str.replace('.', ' ', regex=False)
measles = measles.set_index('state')
measles.head()
<ipython-input-389-12b6c101491b>:2: FutureWarning: The default value of regex will change from True to False in a future version. In addition, single character regular expressions will*not* be treated as literal strings when regex=True.
| | 1930 | 1931 | 1932 | 1933 | 1934 | 1935 | 1936 | 1937 | 1938 | 1939 | 1940 | 1941 | 1942 | 1943 | 1944 | 1945 | 1946 | 1947 | 1948 | 1949 | 1950 | 1951 | 1952 | 1953 | 1954 | 1955 | 1956 | 1957 | 1958 | 1959 | 1960 | 1961 | 1962 | 1963 | 1964 | 1965 | 1966 | 1967 | 1968 | 1969 | 1970 | 1971 | 1972 | 1973 | 1974 | 1975 | 1976 | 1977 | 1978 | 1979 | 1980 | 1981 | 1982 | 1983 | 1984 | 1985 | 1986 | 1987 | 1988 | 1989 | 1990 | 1991 | 1992 | 1993 | 1994 | 1995 | 1996 | 1997 | 1998 | 1999 | 2000 | 2001 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| state | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Alabama | 4389 | 8934 | 270 | 1735 | 15849 | 7214 | 572 | 620 | 13511 | 4381 | 3052 | 8696 | 3564 | 3865 | 7199 | 339 | 3986 | 3693 | 2058 | 11066 | 1503 | 3144 | 11878 | 2799 | 8451 | 2061 | 7117 | 9264 | 7664 | 3467 | 2075 | 2588 | 2379 | 1165 | 18140 | 2346 | 1813 | 1345 | 158 | 12 | 486 | 2086 | 142 | 19 | 21 | 5 | 0 | 79 | 60 | 94 | 22 | 0 | 2 | 0 | 0 | 0 | 0 | 4 | 0 | 14 | 4 | 0 | 0 | 1 | 0 | 0 | 0 | 6 | 0 | 0 | 0 | 0 |
| Alaska | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1487 | 536 | 2511 | 958 | 1259 | 1829 | 1456 | 1002 | 1406 | 1849 | 1169 | 215 | 641 | 148 | 4 | 22 | 73 | 40 | 7 | 5 | 1 | 0 | 11 | 49 | 2 | 18 | 4 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 2 | 0 | 6 | 1 | 7 | 0 | 1 | 0 |
| American Samoa | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| Arizona | 2107 | 2135 | 86 | 1261 | 1022 | 586 | 2378 | 3793 | 604 | 479 | 2078 | 2929 | 3813 | 1002 | 4641 | 332 | 3274 | 1464 | 4303 | 3795 | 2389 | 10454 | 2515 | 4934 | 4959 | 10086 | 6785 | 7659 | 9845 | 9067 | 4305 | 7868 | 6350 | 8878 | 6328 | 1593 | 5536 | 1022 | 243 | 561 | 1006 | 515 | 868 | 11 | 21 | 83 | 233 | 242 | 58 | 84 | 327 | 7 | 16 | 0 | 0 | 0 | 38 | 6 | 0 | 1 | 61 | 52 | 0 | 3 | 9 | 12 | 8 | 9 | 8 | 1 | 0 | 2 |
| Arkansas | 996 | 849 | 99 | 5438 | 7222 | 1518 | 107 | 322 | 6690 | 1724 | 1096 | 5474 | 4542 | 2867 | 3780 | 1098 | 3266 | 2414 | 3824 | 12313 | 1739 | 7625 | 2910 | 13079 | 2619 | 2983 | 8637 | 1358 | 3002 | 816 | 1596 | 1684 | 1438 | 1453 | 1108 | 1195 | 1379 | 1269 | 3 | 3 | 33 | 447 | 13 | 37 | 14 | 2 | 20 | 8 | 18 | 10 | 9 | 14 | 0 | 0 | 9 | 0 | 244 | 0 | 0 | 0 | 32 | 0 | 0 | 0 | 0 | 2 | 0 | 3 | 0 | 1 | 1 | 0 |
# Tidying the data: one row per (state, Year) with the case count.
measles2 = measles.reset_index().melt(id_vars='state', var_name='Year', value_name='count')
# Total cases per state.
bl = measles2.groupby('state')['count'].sum()
ind = bl.argsort()  # Find index order that makes the states ordered by total case
# Creating marginal frequency distributions.
# Only 'count' is aggregated: summing the whole frame would also try to
# "sum" the remaining string column, which newer pandas versions
# concatenate into one long string per group instead of dropping.
d1 = measles2.groupby('Year')['count'].sum().reset_index()
# Reuse the per-state totals rather than grouping a second time.
d2 = bl.reset_index().sort_values(by='count', ascending=False)
# Heatmap of cases (states x years) with marginal bar charts of the totals.
g = sns.JointGrid(ratio=8, height=10)
# Rows are taken in descending order of total cases.
# NOTE(review): `ind` holds positions within the state-sorted totals; using
# them on `measles` assumes its rows are in the same sorted order — confirm.
sns.heatmap(measles.iloc[ind.to_numpy()[::-1], :], cmap='YlOrRd', linewidth=0.1, ax=g.ax_joint, cbar=False)
sns.barplot(x='count', y='state', data=d2, color='yellow', ax=g.ax_marg_y)
sns.barplot(y='count', x='Year', data=d1, color='yellow', ax=g.ax_marg_x)
# axvline expects a single x value, so look up the integer position of the
# '1961' column instead of passing the tuple of arrays np.where returns.
vaccine_col = measles.columns.get_loc('1961')
g.ax_joint.axvline(vaccine_col, linestyle='--')
# Label every tenth column with its decade.
g.ax_joint.set_xticks(np.arange(0, 80, 10))
g.ax_joint.set_xticklabels(np.arange(1930, 2005, 10), rotation=45);
g.ax_joint.set_ylabel('');
g.ax_marg_x.text(31, 100, 'Vaccine introduced', horizontalalignment='center', fontsize='x-large');
# Leave room above the axes for the overall title.
g.fig.subplots_adjust(top=0.9)
g.fig.suptitle("Measles cases in the US, 1930-2001", fontsize='xx-large');
g.fig.savefig('img/measles_sns.png')
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# Plotly version: a 2x2 grid with the heatmap bottom-left, yearly totals on
# top, and state totals on the right (the top-right cell stays empty).
fig1 = make_subplots(rows=2, cols=2, column_widths=[0.8, 0.2], row_heights=[0.2, 0.8], horizontal_spacing=0.05, vertical_spacing=0.05)
# NOTE(review): `measles_px` is not defined in the visible part of this file;
# presumably a states-by-years table derived from `measles` — confirm that it
# is created before this code runs.
fig1.add_trace(
    go.Heatmap(z=measles_px.values, x=measles_px.columns, y=measles_px.index, colorscale='ylorrd', showscale=False),
    row=2, col=1,
)
# Yearly totals. The original call gave no row/col, which places the trace
# on the first subplot; the position is now explicit like the other traces.
fig1.add_trace(
    go.Bar(x=d1.Year, y=d1['count'].values, marker_color='orange'),
    row=1, col=1,
)
# State totals as horizontal bars next to the heatmap.
fig1.add_trace(
    go.Bar(x=d2['count'].values, y=d2.state, orientation='h', marker_color='orange'),
    row=2, col=2,
)
# Dashed vertical marker and its caption on the heatmap.
fig1.add_shape(go.layout.Shape(type='line', x0=31, x1=31, y0=-1, y1=56, line=dict(color='green', width=3, dash='dash')), row=2, col=1)
fig1.add_annotation(xref='x domain', yref='y domain', x=0.45, y=0.05, text='Vaccine available', showarrow=False, row=2, col=1)
# Hide tick labels on the marginal plots.
fig1.update_xaxes(showticklabels=False, row=1, col=1)
fig1.update_yaxes(showticklabels=False, row=2, col=2)
# Reverse both y axes in the bottom row.
fig1.update_yaxes(row=2, col=1, autorange='reversed')
fig1.update_yaxes(row=2, col=2, autorange='reversed')
fig1.update_layout(showlegend=False, template='simple_white')
show_fig(fig1, 'img/measles_plotly.html')
# Display the previously saved seaborn figure as well.
show_fig2('img/measles_sns.png')
# Altair version: heatmap with a bar chart of yearly totals above it and a
# bar chart of state totals beside it.

# Both state-based charts use the row order of d2.
state_order = d2.state.values

# Yearly totals (top margin).
top_bar = (
    alt.Chart(d1)
    .mark_bar()
    .encode(
        x=alt.X('Year', axis=alt.Axis(labels=False, title=None)),
        y=alt.Y('count', axis=alt.Axis(title=None, format='s')),
        tooltip=['Year', alt.Tooltip('count', title="N", format='.3s')],
    )
    .properties(width=600, height=200)
    .interactive()
)

# State totals (right margin).
side_bar = (
    alt.Chart(d2)
    .mark_bar()
    .encode(
        y=alt.Y('state', sort=state_order, axis=alt.Axis(labels=False, title=None)),
        x=alt.X('count', axis=alt.Axis(title=None, format='s')),
        tooltip=['state', alt.Tooltip('count', title="N", format='.3s')],
    )
    .properties(width=200, height=600)
    .interactive()
)

# One rectangle per (state, Year), colored by case count.
heatmap = (
    alt.Chart(measles2)
    .mark_rect()
    .encode(
        x='Year',
        y=alt.Y('state', sort=state_order),
        color=alt.Color('count', scale=alt.Scale(scheme='yelloworangered')),
        tooltip=[
            alt.Tooltip('state', title='State'),
            alt.Tooltip('Year', title='Year'),
            alt.Tooltip('count', title='N', format='.3s'),
        ],
    )
    .properties(width=600, height=600)
    .interactive()
)

# `&` stacks vertically, `|` places side by side.
layout = top_bar & (heatmap | side_bar)
layout.save('img/measles_alt.html')
show_fig2('img/measles_alt.html')